{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "81e97727",
   "metadata": {},
   "source": [
    "# Hugging Face x Amazon SageMaker - Asynchronous Inference with Hugging Face's Transformers\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a1644f1",
   "metadata": {},
   "source": [
    "Welcome to this getting started guide. We will use the Hugging Face Inference DLCs and Amazon SageMaker Python SDK to run an [Asynchronous Inference](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference.html) job.\n",
    "Amazon SageMaker Asynchronous Inference is a new capability in SageMaker that queues incoming requests and processes them asynchronously. Compared to [Batch Transform](https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform.html) [Asynchronous Inference](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference.html) provides immediate access to the results of the inference job rather than waiting for the job to complete.\n",
    "\n",
    "\n",
    "## How it works \n",
    "\n",
    "Asynchronous inference endpoints have many similarities (and some key differences) compared to real-time endpoints. The process to create asynchronous endpoints is similar to real-time endpoints. You need to create: a model, an endpoint configuration, and an endpoint. However, there are specific configuration parameters specific to asynchronous inference endpoints, which we will explore below.\n",
    "\n",
    "The Invocation of asynchronous endpoints differs from real-time endpoints. Rather than pass the request payload in line with the request, you upload the payload to Amazon S3 and pass an Amazon S3 URI as a part of the request. Upon receiving the request, SageMaker provides you with a token with the output location where the result will be placed once processed. Internally, SageMaker maintains a queue with these requests and processes them. During endpoint creation, you can optionally specify an Amazon SNS topic to receive success or error notifications. Once you receive the notification that your inference request has been successfully processed, you can access the result in the output Amazon S3 location.\n",
    "\n",
    "![architecture](./imgs/e2e.png)\n",
    "\n",
    "\n",
    "_NOTE: You can run this demo in Sagemaker Studio, your local machine, or Sagemaker Notebook Instances_\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5109506",
   "metadata": {},
   "source": [
    "## Development Environment and Permissions\n",
    "\n",
    "### Installation \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69c59d90",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install sagemaker --upgrade"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e4386d9",
   "metadata": {},
   "source": [
    "### Permissions\n",
    "\n",
    "_If you are going to use Sagemaker in a local environment (not SageMaker Studio or Notebook Instances). You need access to an IAM Role with the required permissions for Sagemaker. You can find [here](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html) more about it._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c22e8d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "import boto3\n",
    "sess = sagemaker.Session()\n",
    "# sagemaker session bucket -> used for uploading data, models and logs\n",
    "# sagemaker will automatically create this bucket if it not exists\n",
    "sagemaker_session_bucket=None\n",
    "if sagemaker_session_bucket is None and sess is not None:\n",
    "    # set to default bucket if a bucket name is not given\n",
    "    sagemaker_session_bucket = sess.default_bucket()\n",
    "\n",
    "try:\n",
    "    role = sagemaker.get_execution_role()\n",
    "except ValueError:\n",
    "    iam = boto3.client('iam')\n",
    "    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']\n",
    "\n",
    "sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)\n",
    "\n",
    "print(f\"sagemaker role arn: {role}\")\n",
    "print(f\"sagemaker bucket: {sess.default_bucket()}\")\n",
    "print(f\"sagemaker session region: {sess.boto_region_name}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff4bef24",
   "metadata": {},
   "source": [
    "## Create Inference `HuggingFaceModel` for the Asynchronous Inference Endpoint\n",
    "\n",
    "We use the [twitter-roberta-base-sentiment](https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment) model running our async inference job. This is a RoBERTa-base model trained on ~58M tweets and finetuned for sentiment analysis with the TweetEval benchmark.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "708854cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------!"
     ]
    }
   ],
   "source": [
    "from sagemaker.huggingface.model import HuggingFaceModel\n",
    "from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig\n",
    "from sagemaker.s3 import s3_path_join\n",
    "\n",
    "# Hub Model configuration. <https://huggingface.co/models>\n",
    "hub = {\n",
    "    'HF_MODEL_ID':'cardiffnlp/twitter-roberta-base-sentiment',\n",
    "    'HF_TASK':'text-classification'\n",
    "}\n",
    "\n",
    "# create Hugging Face Model Class\n",
    "huggingface_model = HuggingFaceModel(\n",
    "   env=hub,                      # configuration for loading model from Hub\n",
    "   role=role,                    # iam role with permissions to create an Endpoint\n",
    "   transformers_version=\"4.26\",  # transformers version used\n",
    "   pytorch_version=\"1.13\",        # pytorch version used\n",
    "   py_version='py39',            # python version used\n",
    ")\n",
    "\n",
    "# create async endpoint configuration\n",
    "async_config = AsyncInferenceConfig(\n",
    "    output_path=s3_path_join(\"s3://\",sagemaker_session_bucket,\"async_inference/output\") , # Where our results will be stored\n",
    "    # notification_config={\n",
    "            #   \"SuccessTopic\": \"arn:aws:sns:us-east-2:123456789012:MyTopic\",\n",
    "            #   \"ErrorTopic\": \"arn:aws:sns:us-east-2:123456789012:MyTopic\",\n",
    "    # }, #  Notification configuration\n",
    ")\n",
    "\n",
    "# deploy the endpoint endpoint\n",
    "async_predictor = huggingface_model.deploy(\n",
    "    initial_instance_count=1,\n",
    "    instance_type=\"ml.g4dn.xlarge\",\n",
    "    async_inference_config=async_config\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c187832",
   "metadata": {},
   "source": [
    "We can find our Asynchronous Inference endpoint configuration in the Amazon SageMaker Console. Our endpoint now has type `async` compared to a' real-time' endpoint.\n",
    "\n",
    "![deployed-endpoint](./imgs/deployed_endpoint.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6b3812f",
   "metadata": {},
   "source": [
    "## Request Asynchronous Inference Endpoint using the `AsyncPredictor`\n",
    "\n",
    "The `.deploy()` returns an `AsyncPredictor` object which can be used to request inference. This `AsyncPredictor` makes it easy to send asynchronous requests to your endpoint and get the results back. It has two methods: `predict()` and `predict_async()`. The `predict()` method is synchronous and will block until the inference is complete. The `predict_async()` method is asynchronous and will return immediately with the a `AsyncInferenceResponse`, which can be used to check for the result with polling. If the result object exists in that path, get and return the result."
   ]
  },
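  {
   "cell_type": "markdown",
   "id": "7b2e4c1a",
   "metadata": {},
   "source": [
    "Conceptually, polling for a result is a bounded retry loop: check for the result, wait, and give up after a fixed number of attempts. The sketch below illustrates that idea in plain Python. It is not the SDK's implementation; `poll_for_result`, `fetch`, and the use of `TimeoutError` are assumptions made for illustration, and only the `max_attempts` and `delay` names mirror the `WaiterConfig` used later in this notebook.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "\n",
    "def poll_for_result(fetch, max_attempts=5, delay=10, sleep=time.sleep):\n",
    "    # fetch() returns the result, or None while it is not available yet\n",
    "    for attempt in range(max_attempts):\n",
    "        result = fetch()\n",
    "        if result is not None:\n",
    "            return result\n",
    "        # wait before the next attempt, but not after the last one\n",
    "        if attempt < max_attempts - 1:\n",
    "            sleep(delay)\n",
    "    raise TimeoutError(f'no result after {max_attempts} attempts')\n",
    "```\n",
    "\n",
    "With `max_attempts=5` and `delay=10`, this loop spends at most 40 seconds waiting between the first and the last check."
   ]
  },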
  {
   "cell_type": "markdown",
   "id": "a4317470",
   "metadata": {},
   "source": [
    "### `predict()` request example\n",
    "\n",
    "The `predict()` will upload our `data` to Amazon S3 and run inference against it. Since we are using `predict` it will block until the inference is complete."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "51c5366b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'label': 'LABEL_2', 'score': 0.8808117508888245}, {'label': 'LABEL_0', 'score': 0.6126593947410583}, {'label': 'LABEL_2', 'score': 0.9425230622291565}, {'label': 'LABEL_0', 'score': 0.5511414408683777}]\n"
     ]
    }
   ],
   "source": [
    "data = {\n",
    "  \"inputs\": [\n",
    "    \"it 's a charming and often affecting journey .\",\n",
    "    \"it 's slow -- very , very slow\",\n",
    "    \"the mesmerizing performances of the leads keep the film grounded and keep the audience riveted .\",\n",
    "    \"the emotions are raw and will strike a nerve with anyone who 's ever had family trauma .\"\n",
    "  ]\n",
    "}\n",
    "\n",
    "res = async_predictor.predict(data=data)\n",
    "print(res)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a65e3fe9",
   "metadata": {},
   "source": [
    "### `predict_async()` request example\n",
    "\n",
    "The `predict_async()` will upload our `data` to Amazon S3 and run inference against it. Since we are using `predict_async` it will return immediately with an `AsyncInferenceResponse` object. \n",
    "In this example, we will loop over a `csv` file and send each line to the endpoint. After that we are going to poll the endpoint until the inference is complete.\n",
    "The provided `tweet_data.csv` contains ~1800 tweets about different airlines.\n",
    "\n",
    "But first, let's do a quick test to see if we can get a result from the endpoint using `predict_async`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "599351f7",
   "metadata": {},
   "source": [
    "#### Single `predict_async()` request example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "5d972eec",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Response object: <sagemaker.async_inference.async_inference_response.AsyncInferenceResponse object at 0x1047a18b0>\n",
      "Response output path: s3://sagemaker-us-east-1-558105141721/async_inference/output/a3824ec0-ff85-47e4-9843-a84a6b92b25c.out\n",
      "Start Polling to get response:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'label': 'LABEL_2', 'score': 0.973595380783081}]"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sagemaker.async_inference.waiter_config import WaiterConfig\n",
    "\n",
    "resp = async_predictor.predict_async(data={\"inputs\": \"i like you. I love you\"})\n",
    "\n",
    "print(f\"Response object: {resp}\")\n",
    "print(f\"Response output path: {resp.output_path}\")\n",
    "print(\"Start Polling to get response:\")\n",
    "\n",
    "config = WaiterConfig(\n",
    "  max_attempts=5, #  number of attempts\n",
    "  delay=10 #  time in seconds to wait between attempts\n",
    "  )\n",
    "\n",
    "resp.get_result(config)"
   ]
  },
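  {
   "cell_type": "markdown",
   "id": "4d8f0a3e",
   "metadata": {},
   "source": [
    "The output path printed above is an ordinary S3 URI. If you want to fetch the result object yourself, for example with `boto3`, you first need the bucket and the key. A minimal sketch follows; `split_s3_uri` is a hypothetical helper and the URI in the usage line is a made-up example.\n",
    "\n",
    "```python\n",
    "from urllib.parse import urlparse\n",
    "\n",
    "\n",
    "def split_s3_uri(uri):\n",
    "    # split 's3://bucket/some/key' into ('bucket', 'some/key')\n",
    "    parsed = urlparse(uri)\n",
    "    if parsed.scheme != 's3' or not parsed.netloc:\n",
    "        raise ValueError(f'not an S3 URI: {uri}')\n",
    "    return parsed.netloc, parsed.path.lstrip('/')\n",
    "\n",
    "\n",
    "bucket, key = split_s3_uri('s3://my-bucket/async_inference/output/example.out')\n",
    "print(bucket, key)\n",
    "```"
   ]
  },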
  {
   "cell_type": "markdown",
   "id": "991edb25",
   "metadata": {},
   "source": [
    "#### High load `predict_async()` request example using a `csv` file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "fd5768d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All requests sent\n",
      "Output path list length: 1798\n",
      "Output path list sample: s3://sagemaker-us-east-1-558105141721/async_inference/output/2afbf8d1-164e-4c44-864e-5a951cb59adb.out\n",
      "Results length: 1798\n",
      "Results sample: [{'label': 'LABEL_0', 'score': 0.9345051646232605}]\n"
     ]
    }
   ],
   "source": [
    "from csv import reader\n",
    "\n",
    "data_file=\"tweet_data.csv\"\n",
    "\n",
    "output_list = []\n",
    "\n",
    "# open file in read mode\n",
    "with open(data_file, 'r') as csv_reader:\n",
    "    for row in reader(csv_reader):\n",
    "        # send each row as async reuqest request\n",
    "        resp = async_predictor.predict_async(data={\"inputs\": row[0]})\n",
    "        output_list.append(resp)\n",
    "\n",
    "print(\"All requests sent\")    \n",
    "print(f\"Output path list length: {len(output_list)}\")\n",
    "print(f\"Output path list sample: {output_list[26].output_path}\")\n",
    "\n",
    "# iterate over list of output paths and get results\n",
    "results = []\n",
    "for async_response in output_list:\n",
    "    response = async_response.get_result(WaiterConfig())\n",
    "    results.append(response)\n",
    "\n",
    "print(f\"Results length: {len(results)}\")\n",
    "print(f\"Results sample: {results[26]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccc0d04d",
   "metadata": {},
   "source": [
    "## Autoscale (to Zero) the Asynchronous Inference Endpoint\n",
    "\n",
    "Amazon SageMaker supports automatic scaling (autoscaling) your asynchronous endpoint. Autoscaling dynamically adjusts the number of instances provisioned for a model in response to changes in your workload. Unlike other hosted models Amazon SageMaker supports, with Asynchronous Inference, you can also scale down your asynchronous endpoints instances to zero.\n",
    "\n",
    "**Prequistion**: You need to have an running Asynchronous Inference Endpoint up and running. You can check [Create Inference `HuggingFaceModel` for the Asynchronous Inference Endpoint](#create-inference-huggingfacemodel-for-the-asynchronous-inference-endpoint) to see how to create one.\n",
    "\n",
    "If you want to learn more check-out [Autoscale an asynchronous endpoint](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference-autoscale.html) in the SageMaker documentation.\n",
    "\n",
    "\n",
    "We are going to scale our asynchronous endpoint to 0-5 instances, which means that Amazon SageMaker will scale the endpoint to 0 instances after `600` seconds or 10 minutes to save you cost and scale up to 5 instances in `300` seconds steps getting more than `5.0` invocations.  "
   ]
  },
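  {
   "cell_type": "markdown",
   "id": "9c3a6e52",
   "metadata": {},
   "source": [
    "As a rough mental model of what a target of `5.0` queued requests per instance aims for, the sketch below computes the instance count at which the backlog per instance is at or below the target, clamped to the 0-5 range. It is a simplification under stated assumptions: `desired_instances` is a hypothetical helper, and the real policy also depends on cooldowns, CloudWatch alarm evaluation, and the behavior when scaling out from zero, none of which are modeled here.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def desired_instances(backlog, target_per_instance=5.0, min_capacity=0, max_capacity=5):\n",
    "    # smallest instance count that keeps backlog / instances at or below the target\n",
    "    needed = math.ceil(backlog / target_per_instance)\n",
    "    # clamp to the range registered for the scalable target\n",
    "    return max(min_capacity, min(max_capacity, needed))\n",
    "\n",
    "\n",
    "for backlog in (0, 12, 100):\n",
    "    print(f'backlog={backlog} -> instances={desired_instances(backlog)}')\n",
    "```"
   ]
  },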
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "66549548",
   "metadata": {},
   "outputs": [],
   "source": [
    "# application-autoscaling client\n",
    "asg_client = boto3.client(\"application-autoscaling\")\n",
    "\n",
    "# This is the format in which application autoscaling references the endpoint\n",
    "resource_id = f\"endpoint/{async_predictor.endpoint_name}/variant/AllTraffic\"\n",
    "\n",
    "# Configure Autoscaling on asynchronous endpoint down to zero instances\n",
    "response = asg_client.register_scalable_target(\n",
    "    ServiceNamespace=\"sagemaker\",\n",
    "    ResourceId=resource_id,\n",
    "    ScalableDimension=\"sagemaker:variant:DesiredInstanceCount\",\n",
    "    MinCapacity=0,\n",
    "    MaxCapacity=5,\n",
    ")\n",
    "\n",
    "response = asg_client.put_scaling_policy(\n",
    "    PolicyName=f'Request-ScalingPolicy-{async_predictor.endpoint_name}',\n",
    "    ServiceNamespace=\"sagemaker\",  \n",
    "    ResourceId=resource_id, \n",
    "    ScalableDimension=\"sagemaker:variant:DesiredInstanceCount\",\n",
    "    PolicyType=\"TargetTrackingScaling\",\n",
    "    TargetTrackingScalingPolicyConfiguration={\n",
    "        \"TargetValue\": 5.0, \n",
    "        \"CustomizedMetricSpecification\": {\n",
    "            \"MetricName\": \"ApproximateBacklogSizePerInstance\",\n",
    "            \"Namespace\": \"AWS/SageMaker\",\n",
    "            \"Dimensions\": [{\"Name\": \"EndpointName\", \"Value\": async_predictor.endpoint_name}],\n",
    "            \"Statistic\": \"Average\",\n",
    "        },\n",
    "        \"ScaleInCooldown\": 600, # duration until scale in begins (down to zero)\n",
    "        \"ScaleOutCooldown\": 300 # duration between scale out attempts\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5277a95e",
   "metadata": {},
   "source": [
    "![scaling](./imgs/scaling.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6af94176",
   "metadata": {},
   "source": [
    "The Endpoint will now scale to zero after 600s. Let's wait until the endpoint is scaled to zero and then test sending requests and measure how long it takes to start an instance to process the requests. We are using the `predict_async()` method to send the request.\n",
    "\n",
    "_**IMPORTANT: Since we defined the `TargetValue` to `5.0` the Async Endpoint will only start to scale out from 0 to 1 if you are sending more than 5 requests within 300 seconds.**_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03201759",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "output_list=[]\n",
    "\n",
    "# send 10 requests\n",
    "for i in range(10):\n",
    "  resp = async_predictor.predict_async(data={\"inputs\": \"it 's a charming and often affecting journey .\"})\n",
    "  output_list.append(resp)\n",
    "\n",
    "# iterate over list of output paths and get results\n",
    "results = []\n",
    "for async_response in output_list:\n",
    "    response = async_response.get_result(WaiterConfig(max_attempts=600))\n",
    "    results.append(response)\n",
    "\n",
    "print(f\"Time taken: {time.time() - start}s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf31fe42",
   "metadata": {},
   "source": [
    "It took about 7-9 minutes to start an instance and to process the requests. This is perfect when you have non real-time critical applications, but want to save money. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6a09b66",
   "metadata": {},
   "source": [
    "![scale-out](./imgs/scale-out.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb10007d",
   "metadata": {},
   "source": [
    "### Delete the async inference endpoint & Autoscaling policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "1e6fb7b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "response = asg_client.deregister_scalable_target(\n",
    "    ServiceNamespace='sagemaker',\n",
    "    ResourceId=resource_id,\n",
    "    ScalableDimension='sagemaker:variant:DesiredInstanceCount'\n",
    ")\n",
    "async_predictor.delete_model()\n",
    "async_predictor.delete_endpoint()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "c281c456f1b8161c8906f4af2c08ed2c40c50136979eaae69688b01f70e9f4a9"
  },
  "kernelspec": {
   "display_name": "conda_pytorch_p39",
   "language": "python",
   "name": "conda_pytorch_p39"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
